import os
import json
import torch
import tqdm
import hydra
import hydra.utils as hu
from transformers import set_seed
from torch.utils.data import DataLoader
from src.utils.collators import DataCollatorWithPaddingAndCuda
from src.models.biencoder import BiEncoder

def submodular_diverse_select_gpu(
    demo_embeds: torch.Tensor,    # (n, d)
    test_embed:   torch.Tensor,    # (d,)
    k:            int,
    lambd:        float,
    reg:          float = 0.02,
) -> list:
    """
    Greedily select up to ``k`` diverse candidate rows for one query embedding.

    Maintains ``V_S = reg * I + sum_{j in S} x_j x_jᵀ`` over the selected set S
    and keeps ``inv(V_S)`` current with the Sherman–Morrison rank-1 formula, so
    no explicit matrix inversion is ever performed.

    At each step every still-available candidate x_i is scored as

        (qᵀ inv(V_S) x_i)² / (1 + x_iᵀ inv(V_S) x_i) + lambd * (1 + x_iᵀ inv(V_S) x_i)

    with q = ``test_embed``. The highest-scoring candidate is added to S.

    Args:
        demo_embeds: Candidate embeddings, shape (n, d).
        test_embed: Query embedding, shape (d,).
        k: Number of candidates to select. Clamped to n so that no index is
            ever returned twice.
        lambd: Weight of the ``(1 + x_iᵀ inv(V_S) x_i)`` term in the score.
        reg: Ridge term used to initialise ``V_S = reg * I``. The default of
            0.02 matches the value previously hard-coded here.

    Returns:
        A list of ``min(k, n)`` distinct row indices into ``demo_embeds``
        (empty if ``k <= 0``), in selection order.
    """
    device = demo_embeds.device
    n, d = demo_embeds.shape

    # Compute in at least float32. Half-precision inputs would otherwise fail
    # the matmul against the float32 identity, and would be too imprecise for
    # repeated rank-1 updates. float64 inputs keep their full precision.
    work_dtype = torch.promote_types(demo_embeds.dtype, torch.float32)
    demo = demo_embeds.to(work_dtype)
    query = test_embed.to(device=device, dtype=work_dtype)
    demo_T = demo.T  # (d, n); loop-invariant

    # V_S = reg * I  =>  inv(V_S) = (1 / reg) * I
    invV_S = torch.eye(d, device=device, dtype=work_dtype) * (1.0 / reg)

    # True for candidates that have not been selected yet
    candidate_mask = torch.ones(n, dtype=torch.bool, device=device)

    selected = []
    # Never iterate more than n times. Once every candidate is masked to -inf,
    # argmax would return index 0 again and produce duplicate selections.
    for _ in range(min(k, n)):
        # inv(V_S) @ x_i for all candidates at once: (d, n)
        VS_inv_demo = invV_S @ demo_T

        # numerator_i = (q · inv(V_S) x_i)²
        scores_num = (query @ VS_inv_demo).pow(2)  # (n,)

        # denom_i = 1 + x_iᵀ inv(V_S) x_i  (row-wise dot product)
        denom = 1.0 + (demo * VS_inv_demo.T).sum(dim=1)  # (n,)

        scores = scores_num / denom + lambd * denom
        scores = scores.masked_fill(~candidate_mask, float("-inf"))

        best_idx = torch.argmax(scores).item()
        selected.append(best_idx)
        candidate_mask[best_idx] = False

        # Sherman–Morrison:
        #   inv(V_S + x xᵀ) = inv(V_S) - (inv(V_S) x xᵀ inv(V_S)) / (1 + xᵀ inv(V_S) x)
        x = demo[best_idx].unsqueeze(1)  # (d, 1)
        invV_S_x = invV_S @ x            # (d, 1)
        denom_scalar = (x.T @ invV_S_x).item() + 1.0
        invV_S = invV_S - (invV_S_x @ invV_S_x.T) / denom_scalar

    return selected


@hydra.main(config_path="configs", config_name="submodular_retriever")
def main(cfg):
    """
    Select in-context examples for each query and write them to JSON.

    The index passages and the queries are encoded with a BiEncoder. For every
    query, ``submodular_diverse_select_gpu`` picks ``cfg.num_ice`` index rows.
    Each query's original record (looked up by its metadata ``"id"``) is
    written to ``cfg.output_file`` with two additions:

    - ``"ctxs"``: the list of selected index-row positions.
    - ``"ctxs_candidates"``: the same positions, each wrapped in a one-element list.

    Config keys read: seed (optional, default 42), model_config,
    pretrained_model_path, index_reader, dataset_reader, batch_size,
    run_for_n_samples, num_ice, lambd, output_file.
    """
    # reproducibility + cudnn tuning
    set_seed(getattr(cfg, 'seed', 42))
    torch.backends.cudnn.benchmark = True

    # --- Model setup ---
    model_cfg = hu.instantiate(cfg.model_config)
    if cfg.pretrained_model_path:
        print(f"Loading model from: {cfg.pretrained_model_path}")
        model = BiEncoder.from_pretrained(cfg.pretrained_model_path, config=model_cfg)
    else:
        model = BiEncoder(model_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    # --- DataLoaders ---
    # The collator is handed `device` (presumably it moves each batch there).
    # Batches are built in the main process; pin_memory and workers are not used.
    index_reader = hu.instantiate(cfg.index_reader)
    collator = DataCollatorWithPaddingAndCuda(
        tokenizer=index_reader.tokenizer,
        device=device
    )
    index_loader = DataLoader(
        index_reader,
        batch_size=cfg.batch_size,
        collate_fn=collator
    )

    # --- Index embeddings (presumably on `device`, like the model) ---
    index_embeds_list = []
    for batch in tqdm.tqdm(index_loader, desc="Encoding index passages"):
        with torch.no_grad():
            embeds = model.encode(batch["input_ids"], batch["attention_mask"], encode_ctx=True)
        index_embeds_list.append(embeds)

    demo_embeds = torch.cat(index_embeds_list, dim=0)  # (n, d)
    del index_embeds_list  # release the per-batch tensors now that they are concatenated

    # --- Query embeddings ---
    query_reader = hu.instantiate(cfg.dataset_reader)
    query_loader = DataLoader(
        query_reader,
        batch_size=cfg.batch_size,
        collate_fn=collator
    )

    query_embeds_list = []
    query_metadata = []
    for batch in tqdm.tqdm(query_loader, desc="Encoding queries"):
        with torch.no_grad():
            embeds = model.encode(batch["input_ids"], batch["attention_mask"])
        query_embeds_list.append(embeds)
        # Metadata is required: each entry's "id" is used below to fetch the
        # original record, so index directly instead of a default that would
        # fail later with a confusing error.
        query_metadata.extend(batch["metadata"].data)

    query_embeds = torch.cat(query_embeds_list, dim=0)  # (m, d)
    del query_embeds_list

    # --- Submodular selection, optionally limited to the first N queries ---
    num_queries = min(len(query_embeds), len(query_metadata))
    if cfg.run_for_n_samples:
        num_queries = max(0, min(num_queries, cfg.run_for_n_samples))

    results = []
    for q_embed, meta in tqdm.tqdm(
            zip(query_embeds[:num_queries], query_metadata[:num_queries]),
            total=num_queries,
            desc="Selecting contexts"
    ):
        selected_idxs = submodular_diverse_select_gpu(
            demo_embeds=demo_embeds,
            test_embed=q_embed,
            k=cfg.num_ice,
            lambd=cfg.lambd
        )

        orig = query_reader.dataset_wrapper[meta["id"]].copy()
        orig["ctxs"]            = selected_idxs
        orig["ctxs_candidates"] = [[i] for i in selected_idxs]
        results.append(orig)

    # --- Write out JSON ---
    out_dir = os.path.dirname(cfg.output_file)
    # A bare filename has no directory part, and os.makedirs("") would raise.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(cfg.output_file, "w") as fout:
        json.dump(results, fout, indent=2)


# Script entry point. main() is called with no arguments because its
# @hydra.main decorator composes `cfg` from configs/submodular_retriever.
if __name__ == "__main__":
    main()



